package com.jstarcraft.ai.jsat.classifiers.linear.kernelized;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.BaseUpdateableClassifier;
import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.calibration.BinaryScoreClassifier;
import com.jstarcraft.ai.jsat.distributions.Distribution;
import com.jstarcraft.ai.jsat.distributions.LogUniform;
import com.jstarcraft.ai.jsat.distributions.kernels.KernelTrick;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.lossfunctions.HingeLoss;
import com.jstarcraft.ai.jsat.lossfunctions.LossC;
import com.jstarcraft.ai.jsat.parameters.Parameter.ParameterHolder;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;

/**
 * Bounded Online Gradient Descent (BOGD) is a kernel learning algorithm that
 * uses a bounded number of support vectors. Once the maximum number of support
 * vectors is reached, old vectors are dropped either in a uniform random
 * fashion, or weighted by the kernel function and the current coefficient for
 * the vector. The later is the default method and is referred to as BOGD++.<br>
 * <br>
 * See: Zhao, P., Wang, J., Wu, P., Jin, R.,&amp;Hoi, S. C. H. (2012). <i>Fast
 * Bounded Online Gradient Descent Algorithms for Scalable Kernel-Based Online
 * Learning</i>. In Proceedings of the 29th International Conference on Machine
 * Learning (pp. 169–176). Learning; Machine Learning. Retrieved from
 * <a href="http://arxiv.org/abs/1206.4633">here</a>
 * 
 * @author Edward Raff
 */
public class BOGD extends BaseUpdateableClassifier implements BinaryScoreClassifier, Parameterized {

    private static final long serialVersionUID = -3547832514098781996L;
    @ParameterHolder
    private KernelTrick k;
    private int budget;
    private double eta;
    private double reg;
    private double maxCoeff;
    private LossC lossC;

    private boolean uniformSampling;

    private Random rand;
    private List<Vec> vecs;
    /**
     * Stores the sqrt of each support vector's kernel product with itself
     */
    private DoubleArrayList selfK;
    private DoubleArrayList alphas;
    private DoubleArrayList accelCache;
    /**
     * Cache of values used for BOGD++ sampling
     */
    private double[] dist;

    /**
     * Creates a new BOGD++ learner using the {@link HingeLoss}
     * 
     * @param k        the kernel trick to use
     * @param budget   the budget for support vectors to allow
     * @param eta      the learning rate to use
     * @param reg      the regularization parameter
     * @param maxCoeff the maximum support vector coefficient to allow
     */
    public BOGD(KernelTrick k, int budget, double eta, double reg, double maxCoeff) {
        this(k, budget, eta, reg, maxCoeff, new HingeLoss());
    }

    /**
     * Creates a new BOGD++ learner
     * 
     * @param k        the kernel trick to use
     * @param budget   the budget for support vectors to allow
     * @param eta      the learning rate to use
     * @param reg      the regularization parameter
     * @param maxCoeff the maximum support vector coefficient to allow
     * @param lossC    the loss function to use
     */
    public BOGD(KernelTrick k, int budget, double eta, double reg, double maxCoeff, LossC lossC) {
        setKernel(k);
        setBudget(budget);
        setEta(eta);
        setRegularization(reg);
        setMaxCoeff(maxCoeff);
        this.lossC = lossC;
        setUniformSampling(false);
    }

    /**
     * Copy constructor
     * 
     * @param toCopy the object to make a copy of
     */
    public BOGD(BOGD toCopy) {
        this.k = toCopy.k.clone();
        this.budget = toCopy.budget;
        this.eta = toCopy.eta;
        this.reg = toCopy.reg;
        this.maxCoeff = toCopy.maxCoeff;
        this.lossC = toCopy.lossC.clone();
        this.uniformSampling = toCopy.uniformSampling;
        this.rand = RandomUtil.getRandom();
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(budget);
            for (Vec v : toCopy.vecs)
                this.vecs.add(v.clone());
            this.selfK = new DoubleArrayList(toCopy.selfK);
            this.alphas = new DoubleArrayList(toCopy.alphas);
        }
        if (toCopy.accelCache != null)
            this.accelCache = new DoubleArrayList(toCopy.accelCache);
        if (toCopy.dist != null)
            this.dist = Arrays.copyOf(toCopy.dist, toCopy.dist.length);
    }

    /**
     * Sets the regularization parameter used for training. The original paper
     * suggests values in the range 2<sup>x</sup>/T<sup>2</sup> for <i>x</i> &isin;
     * {-3, -2, -1, 0, 1, 2, 3} where <i>T</i> is the number of data instances that
     * will be trained on
     * 
     * @param regularization the positive regularization parameter to use.
     */
    public void setRegularization(double regularization) {
        if (regularization <= 0 || Double.isNaN(regularization) || Double.isInfinite(regularization))
            throw new IllegalArgumentException("Regularization must be positive, not " + regularization);
        this.reg = regularization;
    }

    /**
     * Returns the regularization parameter used
     * 
     * @return the regularization parameter used
     */
    public double getRegularization() {
        return reg;
    }

    /**
     * Sets the learning rate to use for training. The original paper suggests
     * values in the range 2<sup>x</sup> for <i>x</i> &isin; {-3, -2, -1, 0, 1, 2,
     * 3}
     * 
     * @param eta the positive learning rate to use
     */
    public void setEta(double eta) {
        if (eta <= 0 || Double.isNaN(eta) || Double.isInfinite(eta))
            throw new IllegalArgumentException("Eta must be positive, not " + eta);
        this.eta = eta;
    }

    /**
     * Returns the learning rate in use
     * 
     * @return the learning rate in use
     */
    public double getEta() {
        return eta;
    }

    /**
     * Sets the maximum allowed value for any support vector allowed. The original
     * paper suggests values in the range 2<sup>x</sup> for <i>x</i> &isin; {0, 1,
     * 2, 3, 4}
     * 
     * @param maxCoeff the maximum value for any support vector
     */
    public void setMaxCoeff(double maxCoeff) {
        if (maxCoeff <= 0 || Double.isNaN(maxCoeff) || Double.isInfinite(maxCoeff))
            throw new IllegalArgumentException("MaxCoeff must be positive, not " + maxCoeff);
        this.maxCoeff = maxCoeff;
    }

    /**
     * Returns the maximum allowed value for any support vector
     * 
     * @return the maximum allowed value for any support vector
     */
    public double getMaxCoeff() {
        return maxCoeff;
    }

    /**
     * Sets the budget for support vectors
     * 
     * @param budget the allowed budget for support vectors
     */
    public void setBudget(int budget) {
        if (budget <= 0)
            throw new IllegalArgumentException("Budget must be positive, not " + budget);
        this.budget = budget;
    }

    /**
     * Returns the maximum number of allowed support vectors
     * 
     * @return the maximum number of allowed support vectors
     */
    public int getBudget() {
        return budget;
    }

    /**
     * Sets the kernel to use
     * 
     * @param k the kernel to use
     */
    public void setKernel(KernelTrick k) {
        this.k = k;
    }

    /**
     * Returns the kernel to use
     * 
     * @return the kernel to use
     */
    public KernelTrick getKernel() {
        return k;
    }

    /**
     * Sets whether or not support vectors should be removed by uniform sampling or
     * not. The default is {@code false}, which corresponds to BOGD++.
     * 
     * @param uniformSampling {@code true} to use uniform sampling, {@code false}
     *                        otherwise.
     */
    public void setUniformSampling(boolean uniformSampling) {
        this.uniformSampling = uniformSampling;
    }

    /**
     * Returns {@code true } is uniform sampling is in use, or {@code false} if the
     * BOGD++ sampling procedure is in use
     * 
     * @return {@code true } is uniform sampling is in use, or {@code false} if the
     *         BOGD++ sampling procedure is in use
     */
    public boolean isUniformSampling() {
        return uniformSampling;
    }

    @Override
    public BOGD clone() {
        return new BOGD(this);
    }

    @Override
    public void setUp(CategoricalData[] categoricalAttributes, int numericAttributes, CategoricalData predicting) {
        vecs = new ArrayList<Vec>(budget);
        alphas = new DoubleArrayList(budget);
        selfK = new DoubleArrayList(budget);
        if (k.supportsAcceleration())
            accelCache = new DoubleArrayList(budget);
        else
            accelCache = null;
        if (!uniformSampling)
            dist = new double[budget];
        rand = RandomUtil.getRandom();
    }

    @Override
    public double getScore(DataPoint dp) {
        Vec x = dp.getNumericalValues();
        return score(x, k.getQueryInfo(x));
    }

    private double score(Vec x, DoubleList qi) {
        return k.evalSum(vecs, accelCache, alphas.elements(), x, qi, 0, alphas.size());
    }

    @Override
    public void update(DataPoint dataPoint, double weight, int targetClass) {
        final Vec x_t = dataPoint.getNumericalValues();
        final double y_t = targetClass * 2 - 1;

        final DoubleList qi = k.getQueryInfo(x_t);
        final double score = score(x_t, qi);
        final double lossD = lossC.getDeriv(score, y_t);

        if (lossD == 0) {
            new DenseVector(alphas.elements(), 0, alphas.size()).mutableMultiply(1 - eta * reg);
        } else {
            if (vecs.size() < budget) {
                new DenseVector(alphas.elements(), 0, alphas.size()).mutableMultiply(1 - eta * reg);
                alphas.add(-eta * lossD);
                selfK.add(Math.sqrt(k.eval(0, 0, Arrays.asList(x_t), qi)));
                if (k.supportsAcceleration())
                    accelCache.addAll(qi);
                vecs.add(x_t);
            } else// budget maintinance
            {
                final int toRemove;
                final double normalize;
                if (uniformSampling) {
                    toRemove = rand.nextInt(budget);
                    normalize = 1;
                } else {
                    double s = 0;
                    for (int i = 0; i < budget; i++)
                        s += Math.abs(alphas.getDouble(i)) * selfK.getDouble(i);
                    s = (budget - 1) / s;
                    final double target = rand.nextDouble();
                    double cur = 0;
                    int i = -1;
                    while (cur < target) {
                        i++;
                        cur += dist[i] = 1 - s * alphas.getDouble(i) * selfK.getDouble(i);
                    }
                    toRemove = i++;
                    while (i < budget)
                        cur += dist[i] = 1 - s * alphas.getDouble(i) * selfK.getDouble(i++);
                    normalize = cur;
                }

                for (int i = 0; i < budget; i++) {
                    if (i == toRemove)
                        continue;
                    double alpha_i = alphas.getDouble(i);
                    double sign = Math.signum(alpha_i);
                    alpha_i = Math.abs(alpha_i);
                    double tmp = uniformSampling ? 1.0 / budget : dist[i] / normalize;
                    alphas.set(i, sign * Math.min((1 - reg * eta) / (1 - tmp) * alpha_i, maxCoeff * eta));
                }

                // Remove old point
                if (k.supportsAcceleration()) {
                    int catToRet = accelCache.size() / budget;
                    for (int i = 0; i < catToRet; i++)
                        accelCache.removeDouble(toRemove * catToRet);
                }
                alphas.removeDouble(toRemove);
                vecs.remove(toRemove);
                selfK.removeDouble(toRemove);

                // Add new point
                alphas.add(-eta * lossD);
                selfK.add(Math.sqrt(k.eval(0, 0, Arrays.asList(x_t), qi)));
                accelCache.addAll(qi);
                vecs.add(x_t);
            }
        }
    }

    @Override
    public CategoricalResults classify(DataPoint data) {
        Vec x = data.getNumericalValues();
        return lossC.getClassification(score(x, k.getQueryInfo(x)));
    }

    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Guesses the distribution to use for the Regularization parameter
     *
     * @param d the dataset to get the guess for
     * @return the guess for the Regularization parameter
     * @see #setRegularization(double)
     */
    public static Distribution guessRegularization(DataSet d) {
        double T2 = d.size();
        T2 *= T2;

        return new LogUniform(Math.pow(2, -3) / T2, Math.pow(2, 3) / T2);
    }

    /**
     * Guesses the distribution to use for the &eta; parameter
     *
     * @param d the dataset to get the guess for
     * @return the guess for the &eta; parameter
     * @see #setEta(double)
     */
    public static Distribution guessEta(DataSet d) {
        return new LogUniform(Math.pow(2, -3), Math.pow(2, 3));
    }

    /**
     * Guesses the distribution to use for the MaxCoeff parameter
     *
     * @param d the dataset to get the guess for
     * @return the guess for the MaxCoeff parameter
     * @see #setMaxCoeff(double) (double)
     */
    public static Distribution guessMaxCoeff(DataSet d) {
        return new LogUniform(Math.pow(2, 0), Math.pow(2, 4));
    }
}
